{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "dd522981",
   "metadata": {},
   "source": [
    "# Linear quadratic optimal control example\n",
    "\n",
    "Richard M. Murray, 20 Jan 2022 (updated 7 Jul 2024)\n",
    "\n",
    "This example works through the linear quadratic finite time optimal control problem.  We assume that we have a linear system of the form\n",
    "$$\n",
    "\\dot x = A x + Bu \n",
    "$$\n",
    "and that we want to minimize a cost function of the form\n",
    "$$\n",
    "\\int_0^T (x^T Q_x x + u^T Q_u u)\\, dt + x^T(T) P_1 x(T).\n",
    "$$\n",
    "We show how to compute the solution of the Riccati ODE and use it to obtain an optimal (time-varying) linear controller."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "866842ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy as sp\n",
    "import matplotlib.pyplot as plt\n",
    "import control as ct\n",
    "import control.optimal as opt\n",
    "import control.flatsys as fs\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83a32e85",
   "metadata": {},
   "source": [
    "## System dynamics\n",
    "\n",
    "We use the linearized dynamics of the vehicle steering problem as our linear system.  This is mainly for convenience (since we have some intuition about it)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48c1bd7f-0db6-4488-af41-41f685280ec9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vehicle dynamics (bicycle model)\n",
    "\n",
    "# Function to take states, inputs and return the flat flag\n",
    "def _kincar_flat_forward(x, u, params={}):\n",
    "    # Get the parameter values\n",
    "    b = params.get('wheelbase', 3.)\n",
    "    #! TODO: add dir processing\n",
    "\n",
    "    # Create a list of arrays to store the flat output and its derivatives\n",
    "    zflag = [np.zeros(3), np.zeros(3)]\n",
    "\n",
    "    # Flat output is the x, y position of the rear wheels\n",
    "    zflag[0][0] = x[0]\n",
    "    zflag[1][0] = x[1]\n",
    "\n",
    "    # First derivatives of the flat output\n",
    "    zflag[0][1] = u[0] * np.cos(x[2])  # dx/dt\n",
    "    zflag[1][1] = u[0] * np.sin(x[2])  # dy/dt\n",
    "\n",
    "    # First derivative of the angle\n",
    "    thdot = (u[0]/b) * np.tan(u[1])\n",
    "\n",
    "    # Second derivatives of the flat output (setting vdot = 0)\n",
    "    zflag[0][2] = -u[0] * thdot * np.sin(x[2])\n",
    "    zflag[1][2] =  u[0] * thdot * np.cos(x[2])\n",
    "\n",
    "    return zflag\n",
    "\n",
    "# Function to take the flat flag and return states, inputs\n",
    "def _kincar_flat_reverse(zflag, params={}):\n",
    "    # Get the parameter values\n",
    "    b = params.get('wheelbase', 3.)\n",
    "    dir = params.get('dir', 'f')\n",
    "\n",
    "    # Create a vector to store the state and inputs\n",
    "    x = np.zeros(3)\n",
    "    u = np.zeros(2)\n",
    "\n",
    "    # Given the flat variables, solve for the state\n",
    "    x[0] = zflag[0][0]  # x position\n",
    "    x[1] = zflag[1][0]  # y position\n",
    "    if dir == 'f':\n",
    "        x[2] = np.arctan2(zflag[1][1], zflag[0][1])  # tan(theta) = ydot/xdot\n",
    "    elif dir == 'r':\n",
    "        # Angle is flipped by 180 degrees (since v < 0)\n",
    "        x[2] = np.arctan2(-zflag[1][1], -zflag[0][1])\n",
    "    else:\n",
    "        raise ValueError(f\"unknown direction: {dir}\")\n",
    "\n",
    "    # And next solve for the inputs\n",
    "    u[0] = zflag[0][1] * np.cos(x[2]) + zflag[1][1] * np.sin(x[2])\n",
    "    thdot_v = zflag[1][2] * np.cos(x[2]) - zflag[0][2] * np.sin(x[2])\n",
    "    u[1] = np.arctan2(thdot_v, u[0]**2 / b)\n",
    "\n",
    "    return x, u\n",
    "\n",
    "# Function to compute the RHS of the system dynamics\n",
    "def _kincar_update(t, x, u, params):\n",
    "    b = params.get('wheelbase', 3.)             # get parameter values\n",
    "    #! TODO: add dir processing\n",
    "    dx = np.array([\n",
    "        np.cos(x[2]) * u[0],\n",
    "        np.sin(x[2]) * u[0],\n",
    "        (u[0]/b) * np.tan(u[1])\n",
    "    ])\n",
    "    return dx\n",
    "\n",
    "def _kincar_output(t, x, u, params):\n",
    "    return x                            # return x, y, theta (full state)\n",
    "\n",
    "# Create differentially flat input/output system\n",
    "kincar = fs.FlatSystem(\n",
    "    _kincar_flat_forward, _kincar_flat_reverse, name=\"kincar\",\n",
    "    updfcn=_kincar_update, outfcn=_kincar_output,\n",
    "    inputs=('v', 'delta'), outputs=('x', 'y', 'theta'),\n",
    "    states=('x', 'y', 'theta'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbdd78c0-30e9-43f7-9e8d-198ae38c2988",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Utility function to plot lane change maneuver\n",
    "def plot_lanechange(t, y, u, figure=None, yf=None):\n",
    "    # Plot the xy trajectory\n",
    "    plt.subplot(3, 1, 1, label='xy')\n",
    "    plt.plot(y[0], y[1])\n",
    "    plt.xlabel(\"x [m]\")\n",
    "    plt.ylabel(\"y [m]\")\n",
    "    if yf is not None:\n",
    "        plt.plot(yf[0], yf[1], 'ro')\n",
    "\n",
    "    # Plot the inputs as a function of time\n",
    "    plt.subplot(3, 1, 2, label='v')\n",
    "    plt.plot(t, u[0])\n",
    "    plt.xlabel(\"Time $t$ [sec]\")\n",
    "    plt.ylabel(\"$v$ [m/s]\")\n",
    "\n",
    "    plt.subplot(3, 1, 3, label='delta')\n",
    "    plt.plot(t, u[1])\n",
    "    plt.xlabel(\"Time $t$ [sec]\")\n",
    "    plt.ylabel(\"$\\\\delta$ [rad]\")\n",
    "\n",
    "    plt.suptitle(\"Lane change maneuver\")\n",
    "    plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de9d85f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initial conditions\n",
    "x0 = np.array([-40, -2., 0.])\n",
    "u0 = np.array([10, 0])               # only used for linearization\n",
    "Tf = 4\n",
    "\n",
    "# Linearized dynamics\n",
    "sys = kincar.linearize(x0, u0)\n",
    "print(sys)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5c0abe9",
   "metadata": {},
   "source": [
    "## Optimal trajectory generation\n",
    "\n",
    "We generate a trajectory for the system that minimizes the cost function above.  Namely, starting from some initial condition $x(0) = x_0$, we wish to bring the system toward the origin without using too much control effort."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02e9f87c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the cost function and the terminal cost\n",
    "# (try changing these later to see what happens)\n",
    "Qx = np.diag([1, 1, 1])       # state costs\n",
    "Qu = np.diag([1, 1])          # input costs\n",
    "Pf = np.diag([1, 1, 1])       # terminal costs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62c76e5e",
   "metadata": {},
   "source": [
    "### Finite time, linear quadratic optimization\n",
    "\n",
    "The optimal solution satisfies the following equations, which follow from the maximum principle:\n",
    "\n",
    "$$\n",
    "  \\begin{aligned}\n",
    "    \\dot x &= \\left(\\frac{\\partial H}{\\partial \\lambda}\\right)^T\n",
    "      = A x + Bu, \\qquad & x(0) &= x_0, \\\\\n",
    "    -\\dot \\lambda &= \\left(\\frac{\\partial H}{\\partial x}\\right)^T\n",
    "      = Q_x x + A^T \\lambda, \\qquad\n",
    "      & \\lambda(T) &= P_1 x(T), \\\\\n",
    "    0 &= \\left(\\frac{\\partial H}{\\partial u}\\right)^T\n",
    "      = Q_u u + B^T \\lambda. &&\n",
    "  \\end{aligned}\n",
    "$$\n",
    "\n",
    "The last condition can be solved to obtain the optimal controller\n",
    "\n",
    "$$\n",
    "  u = -Q_u^{-1} B^T \\lambda,\n",
    "$$\n",
    "\n",
    "which can be substituted into the equations for the optimal solution.\n",
    "\n",
    "Given the linear nature of the dynamics, we attempt to find a solution\n",
    "by setting $\\lambda(t) = P(t) x(t)$ where $P(t) \\in {\\mathbb R}^{n \\times\n",
    "n}$.  Substituting this into the necessary condition, we obtain\n",
    "\n",
    "$$\n",
    "  \\begin{aligned}\n",
    "    & \\dot\\lambda =\n",
    "      \\dot P x + P \\dot x = \\dot P x + P(A - BQ_u^{-1} B^T P) x, \\\\\n",
    "    & \\quad\\implies\\quad\n",
    "      -\\dot P x - PA x + PBQ_u^{-1}B^T P x = Q_xx + A^T P x.\n",
    "  \\end{aligned}\n",
    "$$\n",
    "\n",
    "This equation is satisfied if we can find $P(t)$ such that\n",
    "\n",
    "$$\n",
    "  -\\dot P = PA + A^T P - P B Q_u^{-1} B^T P + Q_x,\n",
    "  \\qquad P(T) = P_1.\n",
    "$$"
   ]
  },
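  {
   "cell_type": "markdown",
   "id": "3c9e1a7b",
   "metadata": {},
   "source": [
    "As a quick sanity check on the matrix equation, consider a scalar system (a hypothetical example, not the vehicle model used in this notebook) with $\\dot x = a x + b u$, $b \\neq 0$, weights $q_x, q_u > 0$, and terminal weight $p_1 \\geq 0$.  Under these assumptions the Riccati ODE reduces to\n",
    "\n",
    "$$\n",
    "  -\\dot p = 2 a p - \\frac{b^2}{q_u} p^2 + q_x, \\qquad p(T) = p_1,\n",
    "$$\n",
    "\n",
    "and the controller becomes $u = -(b/q_u)\\, p(t)\\, x$.  Setting $\\dot p = 0$ and taking the positive root gives the steady-state value\n",
    "\n",
    "$$\n",
    "  p_\\infty = \\frac{q_u}{b^2} \\left( a + \\sqrt{a^2 + \\frac{b^2 q_x}{q_u}} \\right),\n",
    "$$\n",
    "\n",
    "which is the scalar analogue of the algebraic Riccati solution that $P(t)$ is expected to approach for times well before $T$ when the horizon is long."
   ]
  },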
  {
   "cell_type": "markdown",
   "id": "b63aed88",
   "metadata": {},
   "source": [
    "To solve a final value problem with $P(T) = P_1$, we write the Riccati ODE as $\\dot P = F(P)$, set the \"initial\" condition to $P_1$, and then invert time, so that we solve\n",
    "\n",
    "$$\n",
    "\\frac{dP}{d(-t)} = -\\frac{dP}{dt} = -F(P), \\qquad P(0) = P_1,\n",
    "$$\n",
    "\n",
    "where $-F(P) = PA + A^T P - P B Q_u^{-1} B^T P + Q_x$.  Solving this equation from time $t = 0$ to time $t = T$ will give us a solution that goes from $P(T)$ to $P(0)$."
   ]
  },
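  {
   "cell_type": "markdown",
   "id": "5d2b8e4f",
   "metadata": {},
   "source": [
    "One way to make the time reversal explicit (the symbols $\\tau$ and $\\tilde P$ are introduced here only for illustration) is to define $\\tau = T - t$ and $\\tilde P(\\tau) = P(T - \\tau)$.  By the chain rule,\n",
    "\n",
    "$$\n",
    "  \\frac{d\\tilde P}{d\\tau} = -\\dot P(T - \\tau)\n",
    "    = \\tilde P A + A^T \\tilde P - \\tilde P B Q_u^{-1} B^T \\tilde P + Q_x,\n",
    "  \\qquad \\tilde P(0) = P_1,\n",
    "$$\n",
    "\n",
    "which is an ordinary initial value problem on $\\tau \\in [0, T]$.  The original solution is recovered as $P(t) = \\tilde P(T - t)$, so a sample computed at $\\tau_k$ corresponds to time $t_k = T - \\tau_k$, and the samples must be reversed to be ordered by increasing $t$."
   ]
  },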
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02d74789",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up the Riccati ODE\n",
    "def Pdot_reverse(t, x):\n",
    "    # Get the P matrix from the state by reshaping\n",
    "    P = np.reshape(x, (sys.nstates, sys.nstates))\n",
    "    \n",
    "    # Compute the right hand side of Riccati ODE\n",
    "    Prhs = P @ sys.A + sys.A.T @ P + Qx - \\\n",
    "        P @ sys.B @ np.linalg.inv(Qu) @ sys.B.T @ P\n",
    "        \n",
    "    # Return dP/d(-t) as a vector (*backwards* in time, so no minus sign)\n",
    "    return Prhs.reshape((-1))\n",
    "\n",
    "# Solve the Riccati ODE (converting from matrix to vector and back)\n",
    "P0 = np.reshape(Pf, (-1))\n",
    "Psol = sp.integrate.solve_ivp(Pdot_reverse, (0, Tf), P0)\n",
    "Pfwd = np.reshape(Psol.y, (sys.nstates, sys.nstates, -1))\n",
    "\n",
    "# Reorder the solution in time\n",
    "Prev = Pfwd[:, :, ::-1] \n",
    "trev = Tf - Psol.t[::-1]\n",
    "\n",
    "print(\"Trange = \", trev[0], \"to\", trev[-1])\n",
    "print(\"P[Tf] =\", Prev[:,:,-1])\n",
    "print(\"P[0] =\", Prev[:,:,0])\n",
    "\n",
    "# Internal comparison: show that initial value is close to algebraic solution\n",
    "_, P_lqr, _ = ct.lqr(sys.A, sys.B, Qx, Qu)\n",
    "print(\"P_lqr =\", P_lqr)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4fb1166",
   "metadata": {},
   "source": [
    "To solve the $x$ dynamics, we need a function that evaluates $P(t)$ at an arbitrary time (it is called by the integrator).  We can do this with SciPy's `interp1d`, which by default interpolates along the last axis (here, time):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "728f675b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define an interpolation function for P\n",
    "P = sp.interpolate.interp1d(trev, Prev)\n",
    "\n",
    "print(\"P(0) =\", P(0))\n",
    "print(\"P(3.5) =\", P(3.5))\n",
    "print(\"P(4) =\", P(4))"
   ]
  },
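  {
   "cell_type": "markdown",
   "id": "7a4f0c16",
   "metadata": {},
   "source": [
    "With $P(t)$ available, substituting the optimal controller $u = -Q_u^{-1} B^T P(t) x$ into the dynamics gives a linear time-varying closed loop system.  Writing the gain as $K(t)$ (a name introduced here for convenience),\n",
    "\n",
    "$$\n",
    "  K(t) = Q_u^{-1} B^T P(t), \\qquad u = -K(t) x,\n",
    "$$\n",
    "\n",
    "the closed loop dynamics are\n",
    "\n",
    "$$\n",
    "  \\dot x = \\bigl(A - B K(t)\\bigr) x, \\qquad x(0) = x_0.\n",
    "$$"
   ]
  },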
  {
   "cell_type": "markdown",
   "id": "eb30c3fa",
   "metadata": {},
   "source": [
    "We now solve the $\\dot x$ equations *forward* in time, using $P(t)$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84092dcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now solve the state forward in time\n",
    "def xdot_forward(t, x):\n",
    "    u = -np.linalg.inv(Qu) @ sys.B.T @ P(t) @ x\n",
    "    return sys.A @ x + sys.B @ u\n",
    "\n",
    "# Now simulate from a shifted initial condition\n",
    "xsol = sp.integrate.solve_ivp(xdot_forward, (0, Tf), x0)\n",
    "tvec = xsol.t\n",
    "x = xsol.y\n",
    "print(\"x[0] =\", x[:, 0])\n",
    "print(\"x[Tf] =\", x[:, -1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8488acad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Finally compute the \"desired\" state and input values\n",
    "xd = x\n",
    "ud = np.zeros((sys.ninputs, tvec.size))\n",
    "for i, t in enumerate(tvec):\n",
    "    ud[:, i] = -np.linalg.inv(Qu) @ sys.B.T @ P(t) @ x[:, i]\n",
    "\n",
    "plot_lanechange(tvec, xd, ud)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89483f4b",
   "metadata": {},
   "source": [
    "Note that here we are stabilizing the system to the origin (in contrast to some of the other examples, where we change lanes and so the final $y$ position is $y_\\text{f} = 2$)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ed4c5eb",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
